/***************************************************************************
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
***************************************************************************/
#include "stdafx.h"
#include "ComputeParallelReduction.h"
#include "ParallelReductionType.h"

namespace Falcor
{
    // Shader source file that provides both compute entry points ("initialPass" and "finalPass") loaded in init().
    static const char kShaderFile[] = "Utils/Algorithm/ParallelReduction.cs.slang";
    // Shader model string passed to ComputeProgram::createFromFile() for both programs.
    static const char kShaderModel[] = "6_0";

    /** Factory function. Constructs a new reduction object and runs init() on it.
        \return The new object, or nullptr if init() reported a failure.
    */
    ComputeParallelReduction::SharedPtr ComputeParallelReduction::create()
    {
        SharedPtr pReduction = SharedPtr(new ComputeParallelReduction());
        if (!pReduction->init())
        {
            // Initialization failed; the partially set up object is released when pReduction goes out of scope.
            return nullptr;
        }
        return pReduction;
    }

    bool ComputeParallelReduction::init()
    {
        // Create the programs.
        // Set defines to avoid compiler warnings about undefined macros. Proper values will be assigned at runtime.
        Program::DefineList defines = { { "FORMAT_CHANNELS", "1" }, { "FORMAT_TYPE", "1" } };
        if (!(mpInitialProgram = ComputeProgram::createFromFile(kShaderFile, "initialPass", defines, Shader::CompilerFlags::None, kShaderModel))) return false;
        if (!(mpFinalProgram = ComputeProgram::createFromFile(kShaderFile, "finalPass", defines, Shader::CompilerFlags::None, kShaderModel))) return false;
        if (!(mpVars = ComputeVars::create(mpInitialProgram.get()))) return false;

        // Check assumptions on thread group sizes. The initial pass is a 2D dispatch, the final pass a 1D.
        assert(mpInitialProgram->getReflector()->getThreadGroupSize().z == 1);
        assert(mpFinalProgram->getReflector()->getThreadGroupSize().y == 1 && mpFinalProgram->getReflector()->getThreadGroupSize().z == 1);

        mpState = ComputeState::create();

        return true;
    }

    /** Makes sure both intermediate buffers are large enough for a reduction over `elementCount` elements.
        Existing buffers are kept if they already have sufficient capacity; buffers only ever grow.

        The two buffers are validated independently. If a previous call managed to create buffer 0 but
        failed to create buffer 1, a later call must retry buffer 1 instead of reporting success just
        because buffer 0 is already big enough.

        \param[in] elementCount Number of elements written by the initial pass (one per tile).
        \return True if both buffers exist with sufficient capacity, false if an allocation failed.
    */
    bool ComputeParallelReduction::allocate(uint32_t elementCount)
    {
        // Buffer 0 has one element per tile.
        if (mpBuffers[0] == nullptr || mpBuffers[0]->getElementCount() < elementCount)
        {
            mpBuffers[0] = TypedBuffer<glm::uvec4>::create(elementCount);
            if (!mpBuffers[0]) return false;
        }

        // Buffer 1 has one element per N elements in buffer 0, where N is the thread group width of the final pass.
        const uint32_t numElem1 = div_round_up(elementCount, mpFinalProgram->getReflector()->getThreadGroupSize().x);
        if (mpBuffers[1] == nullptr || mpBuffers[1]->getElementCount() < numElem1)
        {
            mpBuffers[1] = TypedBuffer<glm::uvec4>::create(numElem1);
            if (!mpBuffers[1]) return false;
        }

        return true;
    }

    /** Runs the parallel reduction over all pixels of a texture.

        The reduction is done in multiple compute passes: an initial pass reduces tiles of pixels into
        one element per tile, then the final pass is repeated, ping-ponging between the two intermediate
        buffers, until a single element remains.

        \param[in] pRenderContext Render context used for the dispatches and the optional buffer copy.
        \param[in] pInput Input texture. Must have exactly one array slice, one mip level and one sample.
        \param[in] operation Reduction operation. NOTE(review): this argument is not referenced by this function;
                             the reduction is determined by the shader entry points - confirm this is intended.
        \param[out] pResult Optional. If non-null, the result is read back to the CPU and stored here.
        \param[out] pResultBuffer Optional. If non-null, the 16-byte result is copied into this GPU buffer.
        \param[in] resultOffset Byte offset into pResultBuffer at which the result is written.
        \return True on success. False if the input is null or unsupported, T does not match the texture format,
                the intermediate buffers could not be allocated, or pResultBuffer is too small.
    */
    template<typename T>
    bool ComputeParallelReduction::execute(RenderContext* pRenderContext, const Texture::SharedPtr& pInput, Type operation, T* pResult, Buffer::SharedPtr pResultBuffer, uint64_t resultOffset)
    {
        PROFILE("ComputeParallelReduction::execute");

        // T::value_type is a dependent name and must be introduced with 'typename' to be valid standard C++.
        using ValueType = typename T::value_type;

        // Size in bytes of one element in the intermediate buffers, which is also the size of the final result.
        constexpr uint64_t kResultSize = sizeof(glm::uvec4);
        static_assert(sizeof(T) <= kResultSize, "Result type T must not be larger than one element of the intermediate buffers");

        assert(pRenderContext);
        if (!pInput)
        {
            logError("ComputeParallelReduction::execute() - Input texture is null. Aborting.");
            return false;
        }

        // Check texture array/mip/sample count.
        if (pInput->getArraySize() != 1 || pInput->getMipCount() != 1 || pInput->getSampleCount() != 1)
        {
            logError("ComputeParallelReduction::execute() - Input texture is unsupported. Aborting.");
            return false;
        }

        // Check texture format.
        uint32_t formatType = FORMAT_TYPE_UNKNOWN;
        switch (getFormatType(pInput->getFormat()))
        {
        case FormatType::Float:
        case FormatType::Unorm:
        case FormatType::Snorm:
            formatType = FORMAT_TYPE_FLOAT;
            break;
        case FormatType::Sint:
            formatType = FORMAT_TYPE_SINT;
            break;
        case FormatType::Uint:
            formatType = FORMAT_TYPE_UINT;
            break;
        default:
            logError("ComputeParallelReduction::execute() - Input texture format unsupported. Aborting.");
            return false;
        }

        // Check that reduction type T is compatible with the resource format.
        if (sizeof(ValueType) != 4 ||     // The shader is written for 32-bit types
            (formatType == FORMAT_TYPE_FLOAT && !std::is_floating_point<ValueType>::value) ||
            (formatType == FORMAT_TYPE_SINT && (!std::is_integral<ValueType>::value || !std::is_signed<ValueType>::value)) ||
            (formatType == FORMAT_TYPE_UINT && (!std::is_integral<ValueType>::value || !std::is_unsigned<ValueType>::value)))
        {
            logError("ComputeParallelReduction::execute() - Template type T is not compatible with resource format. Aborting.");
            return false;
        }

        // Validate the optional results buffer up front, so that no GPU work is submitted for a call that is going to fail.
        // The comparison is written so that a very large resultOffset cannot wrap around.
        if (pResultBuffer)
        {
            const uint64_t resultBufferSize = pResultBuffer->getSize();
            if (resultBufferSize < kResultSize || resultOffset > resultBufferSize - kResultSize)
            {
                logError("ComputeParallelReduction::execute() - Results buffer is too small. Aborting.");
                return false;
            }
        }

        // Allocate intermediate buffers if needed.
        const glm::uvec2 resolution = glm::uvec2(pInput->getWidth(), pInput->getHeight());
        assert(resolution.x > 0 && resolution.y > 0);

        const glm::uvec2 numTiles = div_round_up(resolution, glm::uvec2(mpInitialProgram->getReflector()->getThreadGroupSize()));
        if (!allocate(numTiles.x * numTiles.y))
        {
            logError("ComputeParallelReduction::execute() - Failed to allocate intermediate buffers. Aborting.");
            return false;
        }

        assert(mpBuffers[0]);
        assert(mpBuffers[1]);

        // Configure program.
        const uint32_t channelCount = getFormatChannelCount(pInput->getFormat());
        assert(channelCount >= 1 && channelCount <= 4);
        mpInitialProgram->addDefine("FORMAT_CHANNELS", std::to_string(channelCount));
        mpFinalProgram->addDefine("FORMAT_CHANNELS", std::to_string(channelCount));

        mpInitialProgram->addDefine("FORMAT_TYPE", std::to_string(formatType));
        mpFinalProgram->addDefine("FORMAT_TYPE", std::to_string(formatType));

        // Initial pass: Reduction over tiles of pixels in input texture.
        mpVars["PerFrameCB"]["gResolution"] = resolution;
        mpVars["PerFrameCB"]["gNumTiles"] = numTiles;
        mpVars["gInput"] = pInput;
        mpVars->setTypedBuffer("gResult", mpBuffers[0]);

        mpState->setProgram(mpInitialProgram);
        const glm::uvec3 initialGroups = div_round_up(glm::uvec3(resolution.x, resolution.y, 1), mpInitialProgram->getReflector()->getThreadGroupSize());
        pRenderContext->dispatch(mpState.get(), mpVars.get(), initialGroups);

        // Final pass(es): Reduction by a factor N for each pass.
        // The buffers are ping-ponged: each pass reads from one buffer and writes to the other.
        uint32_t elems = numTiles.x * numTiles.y;
        uint32_t inputsBufferIndex = 0;

        while (elems > 1)
        {
            mpVars["PerFrameCB"]["gElems"] = elems;
            mpVars->setTypedBuffer("gInputBuffer", mpBuffers[inputsBufferIndex]);
            mpVars->setTypedBuffer("gResult", mpBuffers[1 - inputsBufferIndex]);

            mpState->setProgram(mpFinalProgram);
            const uint32_t finalGroups = div_round_up(elems, mpFinalProgram->getReflector()->getThreadGroupSize().x);
            pRenderContext->dispatch(mpState.get(), mpVars.get(), { finalGroups, 1, 1 });

            // The output of this pass is the input of the next one, with one element per dispatched group.
            inputsBufferIndex = 1 - inputsBufferIndex;
            elems = finalGroups;
        }

        // At this point the result is the first element of mpBuffers[inputsBufferIndex].

        // Copy the result to GPU buffer. The buffer size was validated above.
        if (pResultBuffer)
        {
            pRenderContext->copyBufferRegion(pResultBuffer.get(), resultOffset, mpBuffers[inputsBufferIndex].get(), 0, kResultSize);
        }

        // Read back the result to the CPU.
        if (pResult)
        {
            const T* pBuf = static_cast<const T*>(mpBuffers[inputsBufferIndex]->map(Buffer::MapType::Read));
            assert(pBuf);
            *pResult = *pBuf;
            mpBuffers[inputsBufferIndex]->unmap();
        }

        return true;
    }

    // Explicit template instantiation of the supported types.
    // execute() is defined in this source file, so it is instantiated here once for each supported
    // result type: glm::vec4, glm::ivec4 and glm::uvec4.
    // NOTE(review): dlldecl is presumably the library's export/import macro - confirm against its definition.
    template dlldecl bool ComputeParallelReduction::execute<glm::vec4>(RenderContext* pRenderContext, const Texture::SharedPtr& pInput, Type operation, glm::vec4* pResult, Buffer::SharedPtr pResultBuffer, uint64_t resultOffset);
    template dlldecl bool ComputeParallelReduction::execute<glm::ivec4>(RenderContext* pRenderContext, const Texture::SharedPtr& pInput, Type operation, glm::ivec4* pResult, Buffer::SharedPtr pResultBuffer, uint64_t resultOffset);
    template dlldecl bool ComputeParallelReduction::execute<glm::uvec4>(RenderContext* pRenderContext, const Texture::SharedPtr& pInput, Type operation, glm::uvec4* pResult, Buffer::SharedPtr pResultBuffer, uint64_t resultOffset);
}
